import torch
import numpy as np
import click
from bgflow.utils import (
    distance_vectors,
    distances_from_vectors,
)
from bgflow import (
    DiffEqFlow,
    BoltzmannGenerator,
    MeanFreeNormalDistribution,
    BlackBoxDynamics,
)
from bgflow.utils import assert_numpy
from bgflow.bg import sampling_efficiency
import tqdm
from eq_ot_flow.estimator import BruteForceEstimatorFast
from path_grad_helpers import (
    device,
    load_weights,
)
import json

from bgmol.datasets import AImplicitUnconstrained
import mdtraj as md
from bgflow.utils import (
    as_numpy,
)
from eq_ot_flow.utilities import (
    create_adjacency_list,
)
from bgflow import XTBEnergy, XTBBridge
from eq_ot_flow.models import EGNN_dynamics_AD2_cat
import networkx.algorithms.isomorphism as iso
import networkx as nx
from networkx import isomorphism
import scipy


# Prefer CUDA when it is available and fall back to the CPU otherwise.
# NOTE: this rebinds the `device` name imported from `path_grad_helpers` above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def compute_foward_ess(log_weight_hist):
    """Return the forward effective sample size of a set of log-weights.

    The log-weights are flattened, shifted by their maximum (so the largest
    weight is exactly 1) and exponentiated to give ``w``.  The result is
    ``1 / (mean(w) * mean(1 / w))``, returned as a Python float.

    Args:
        log_weight_hist: Tensor of unnormalized log importance weights; any
            shape is accepted because it is flattened first.

    Returns:
        The forward ESS as a float.
    """
    flat_log_w = log_weight_hist.flatten()
    with torch.no_grad():
        # Subtracting the maximum keeps the exponentials bounded by one.
        largest = flat_log_w.max(-1).values
        weights = (flat_log_w - largest).exp()
        reciprocal_weights = 1 / weights

    mean_weight = weights.mean()
    mean_reciprocal = reciprocal_weights.mean()
    return (1 / mean_weight / mean_reciprocal).item()


@click.command()
@click.option("--path", type=str, default="models")
@click.option("--data_path", default="/data")
@click.option("--n_sample_batches", default=20)
@click.option("--batch-size", default=10000)
@click.option("--n_knots_hutch", default=20)
@click.option("--xtb", default=True)
@click.option("--transferable", default=False)
@click.option("--id", default="")
@click.option("--sample_q", default=True)
@click.option("--sample_p", default=True)
def main(
    path,
    n_sample_batches,
    data_path,
    batch_size,
    xtb,
    transferable,
    id,
    sample_q,
    sample_p,
    n_knots_hutch=20,
):
    """Evaluate a trained Boltzmann generator and write the metrics to JSON.

    Builds the flow on the ``AImplicitUnconstrained`` dataset, loads weights
    from ``path`` and, depending on the flags, computes

    * reverse metrics from generator samples (``sample_q``): ``ess_q`` and
      ``KL(Q|P) - c``;
    * forward metrics from held-out data (``sample_p``): ``ess_p``,
      ``KL(P|Q) + c`` and ``-log Q``.

    The results dictionary is printed and dumped to
    ``{path}-results-XTB{xtb}{id}.json``.

    Args:
        path: Passed to ``load_weights``; also the prefix of the results file.
        n_sample_batches: Number of ``bg.sample`` calls when ``sample_q`` is set.
        data_path: Directory containing ``AD2_relaxed_holdout.npy`` (only read
            when both ``xtb`` and ``sample_p`` are set).
        batch_size: Samples drawn per batch and holdout chunk size.
        xtb: If true, target energies come from ``XTBEnergy``; otherwise from
            ``dataset.get_energy_model()`` after aligning sample topology.
        transferable: Selects which atom-type encoding is fed to the network.
        id: Suffix appended to the results file name.
        sample_q: Enable evaluation on generator samples.
        sample_p: Enable evaluation on held-out data.
        n_knots_hutch: Unused.  It must be accepted because the
            ``--n_knots_hutch`` option is declared and click forwards every
            declared option to the callback as a keyword argument.
    """
    print(f"Loading dataset - xtb: {xtb}")
    dataset = AImplicitUnconstrained(read=True)

    n_particles = 22
    n_dimensions = 3
    dim = n_particles * n_dimensions

    target = dataset.get_energy_model()

    ala_traj = md.Trajectory(dataset.xyz, dataset.system.mdtraj_topology)

    # First character of each atom name, used as its element symbol by every
    # encoding below.
    element_symbols = [atom.name[0] for atom in ala_traj.topology.atoms]

    if transferable:
        # atom types for tbg
        atom_types = np.arange(22)
        atom_types[[1, 2, 3]] = 2
        atom_types[[19, 20, 21]] = 20
        atom_types[[11, 12, 13]] = 12
    else:
        atom_dict = {"H": 0, "C": 1, "N": 2, "O": 3}
        atom_types = np.array([atom_dict[symbol] for symbol in element_symbols])

        # Make the backbone atoms distinguishable
        atom_types[[4, 6, 8, 14, 16]] = np.arange(4, 9)

    # Separate element encoding used only as node labels for graph matching.
    atom_dict_aligning = {"C": 0, "H": 1, "N": 2, "O": 3}
    atom_types_for_aligning = np.array(
        [atom_dict_aligning[symbol] for symbol in element_symbols]
    )

    h_initial = torch.nn.functional.one_hot(torch.tensor(atom_types))

    # Samples are divided by this factor before any target energy is evaluated
    # (presumably a length-unit conversion -- TODO confirm).
    scaling = 10

    # now set up a prior
    prior = MeanFreeNormalDistribution(dim, n_particles, two_event_dims=False).to(
        device
    )

    print("Building Flow")
    # Build the Boltzmann Generator
    net_dynamics = EGNN_dynamics_AD2_cat(
        n_particles=n_particles,
        device=device,
        n_dimension=dim // n_particles,
        h_initial=h_initial,
        hidden_nf=64,
        act_fn=torch.nn.SiLU(),
        n_layers=5,
        recurrent=True,
        tanh=True,
        attention=True,
        condition_time=True,
        mode="egnn_dynamics",
        agg="sum",
    )

    bb_dynamics = BlackBoxDynamics(
        dynamics_function=net_dynamics, divergence_estimator=BruteForceEstimatorFast()
    )
    flow = DiffEqFlow(dynamics=bb_dynamics)
    # having a flow and a prior, we can now define a Boltzmann Generator

    bg = BoltzmannGenerator(prior, flow, target.to(device))
    if path is not None:
        print(f"Loading state_dict from {path}")
        load_weights(bg, path)

    topology = dataset.system.mdtraj_topology
    adj_list = torch.from_numpy(
        np.array(
            [(b.atom1.index, b.atom2.index) for b in topology.bonds], dtype=np.int32
        )
    )
    # Reference bond list as plain Python pairs; it never changes, so build it
    # once instead of once per sample.
    reference_bonds = as_numpy(adj_list).tolist()

    temperature = 300
    # Atomic numbers handed to the XTB bridge.
    atomic_numbers = {"H": 1, "C": 6, "N": 7, "O": 8}
    numbers = np.array([atomic_numbers[symbol] for symbol in element_symbols])

    def align_topology(sample, reference, scaling=scaling):
        """Permute the atoms of ``sample`` to match the ``reference`` bond graph.

        Args:
            sample: Coordinates of one configuration; reshaped to ``(-1, 3)``.
            reference: Edge list of the reference bond graph.
            scaling: Divisor applied to the pairwise distances before they are
                passed to ``create_adjacency_list``.

        Returns:
            ``(sample, is_isomorphic)``.  The coordinates are reordered in
            place only when an element-preserving isomorphism was found.
        """
        sample = sample.reshape(-1, 3)
        all_dists = scipy.spatial.distance.cdist(sample, sample)
        adj_list_computed = create_adjacency_list(
            all_dists / scaling, atom_types_for_aligning
        )
        G_reference = nx.Graph(reference)
        G_sample = nx.Graph(adj_list_computed)
        # not same number of nodes
        if len(G_sample.nodes) != len(G_reference.nodes):
            return sample, False
        for i, atom_type in enumerate(atom_types_for_aligning):
            G_reference.nodes[i]["type"] = atom_type
            G_sample.nodes[i]["type"] = atom_type

        nm = iso.categorical_node_match("type", -1)
        GM = isomorphism.GraphMatcher(G_reference, G_sample, node_match=nm)
        if not GM.is_isomorphic():
            # No valid mapping exists; the caller discards such samples.
            return sample, False

        initial_idx = list(GM.mapping.keys())
        final_idx = list(GM.mapping.values())
        sample[initial_idx] = sample[final_idx]
        return sample, True

    target_xtb = XTBEnergy(
        XTBBridge(numbers=numbers, temperature=temperature, solvent="water"),
        two_event_dims=False,
    )

    # Flat accumulators; np.append flattens whatever is appended.
    latent_np = np.empty(shape=(0))
    samples_np = np.empty(shape=(0))
    log_w_np = np.empty(shape=(0))

    energies_np_xtb = np.empty(shape=(0))
    dlogp_np = np.empty(shape=(0))
    distances_x_np = np.empty(shape=(0))
    results = {}

    if sample_q:
        # Constant added to the XTB energies before the weights are formed.
        energy_offset = 34600
        results["Q samples"] = batch_size * n_sample_batches
        for _ in tqdm.tqdm(range(n_sample_batches)):
            with torch.no_grad():
                samples, latent, dlogp = bg.sample(
                    batch_size, with_latent=True, with_dlogp=True
                )
                if not xtb:
                    aligned_mask = []
                    aligned_samples = []
                    for sample in samples:
                        aligned_sample, is_isomorphic = align_topology(
                            assert_numpy(sample), reference_bonds
                        )
                        aligned_mask.append(is_isomorphic)
                        if is_isomorphic:
                            aligned_samples.append(torch.Tensor(aligned_sample))

                    if not any(aligned_mask):
                        print(f"None out of {samples.shape[0]} could be aligned!!!")
                        # Skip the whole batch, including its latents/samples.
                        continue
                    aligned_samples = torch.stack(aligned_samples) / scaling
                    classical_energies = target.energy(
                        aligned_samples.reshape(-1, n_particles * n_dimensions).to(
                            device
                        )
                    )
                    # Only the samples that could be aligned contribute weights.
                    log_weights_classical = assert_numpy(
                        -classical_energies.reshape(-1, 1)
                        + prior.energy(latent.to(device))[aligned_mask]
                        + dlogp.reshape(-1, 1)[aligned_mask]
                    )
                    log_w_np = np.append(log_w_np, log_weights_classical)
                latent_np = np.append(latent_np, latent.detach().cpu().numpy())
                samples_np = np.append(samples_np, samples.detach().cpu().numpy())
                distances_x = (
                    distances_from_vectors(
                        distance_vectors(samples.view(-1, n_particles, n_dimensions))
                    )
                    .detach()
                    .cpu()
                    .numpy()
                    .reshape(-1)
                )
                distances_x_np = np.append(distances_x_np, distances_x)

                if xtb:
                    energies = as_numpy(
                        target_xtb.energy(samples.detach().cpu() / scaling)
                    )
                    energies_np_xtb = np.append(energies_np_xtb, energies)
                dlogp_np = np.append(dlogp_np, as_numpy(dlogp))

        latent_np = latent_np.reshape(-1, dim)
        samples_np = samples_np.reshape(-1, dim)

        # Use the module-level device so the script also runs without CUDA.
        latent_energies = prior.energy(torch.from_numpy(latent_np).to(device))
        if xtb:
            energies_np_xtb += energy_offset

            log_w_np_xtb = (
                -as_numpy(energies_np_xtb).reshape(-1, 1)
                + as_numpy(latent_energies)
                + dlogp_np.reshape(-1, 1)
            )
            # Drop samples whose log-weight is not finite.
            energies_np_xtb = energies_np_xtb[np.isfinite(log_w_np_xtb).flatten()]
            log_w_np_xtb = log_w_np_xtb[np.isfinite(log_w_np_xtb)]
            ess_q = sampling_efficiency(torch.from_numpy(log_w_np_xtb)).item()
            results["KL(Q|P) - c"] = float(-np.mean(log_w_np_xtb))

        else:
            ess_q = sampling_efficiency(torch.from_numpy(log_w_np)).item()
            print(f"Estimating ess on {len(log_w_np)} samples")
            results["aligned samples"] = len(log_w_np)
            results["KL(Q|P) - c"] = float(-np.mean(log_w_np))
        print(f"ESS-Q{ess_q} XTB:{xtb}")
        results["ess_q"] = float(ess_q)

    if sample_p:
        if xtb:
            data_holdout = (
                torch.from_numpy(np.load(f"{data_path}/AD2_relaxed_holdout.npy"))
                .reshape(-1, 66)
                .to(torch.float32)
            )
        else:
            # Every tenth dataset frame, multiplied by `scaling`.
            data_holdout = torch.from_numpy(dataset.xyz[::10].reshape(-1, 66)) * scaling
            print(f"Using {data_holdout.shape} data")
        bg_nrjs_data = []
        target_nrjs_data = []
        with torch.no_grad():
            for start in tqdm.tqdm(range(0, data_holdout.shape[0], batch_size)):
                batch = data_holdout[start : start + batch_size]
                bg_nrjs_data.append(assert_numpy(bg.energy(batch.to(device))))
                if xtb:
                    # The XTB target is evaluated on the un-moved batch.
                    target_nrjs_data.append(
                        assert_numpy(target_xtb.energy(batch / scaling))
                    )
                else:
                    target_nrjs_data.append(
                        assert_numpy(target.energy(batch.to(device) / scaling))
                    )

        bg_energies_data = np.concatenate(bg_nrjs_data)
        target_energies_data = np.concatenate(target_nrjs_data)
        log_w_tilde_p = bg_energies_data - target_energies_data
        ess_p = compute_foward_ess(torch.from_numpy(log_w_tilde_p))
        print(f"ESS-P {ess_p}")
        results["ess_p"] = float(ess_p)
        results["KL(P|Q) + c"] = float(log_w_tilde_p.mean())
        results["-log Q"] = float(np.mean(bg_energies_data))
    print(results)
    # Context manager guarantees the results file is flushed and closed.
    with open(f"{path}-results-XTB{xtb}{id}.json", "w") as results_file:
        json.dump(results, results_file)


if __name__ == "__main__":
    # Run the click command only when this file is executed as a script.
    main()
